pefe-ief Visualization¶
0. Configuration¶
Make sure the values below are correct before running the rest of the notebook; every later cell depends on them.
In [1]:
# The results directory as produced by the pefe-ief library.
# load_results() joins this path with "index.msgpack" and opens that file,
# so a relative value is resolved against the kernel's working directory.
RESULTS_DIR = "../RESULTS"
1. Utilities¶
Most of the time, you don't need to tinker with this; just proceed to the next section.
a) Loading results¶
In [2]:
from pathlib import Path
import msgpack
import msgpack_numpy
from pprint import pprint
msgpack_numpy.patch()
import numpy as np
import matplotlib.pyplot as plt
def load_results(results_dir):
    # type: (str) -> list[dict]
    """Load the results index stored in a pefe-ief results directory.

    Args:
        results_dir: Path of the directory that contains ``index.msgpack``.
            Anything accepted by ``pathlib.Path`` works.

    Returns:
        The object decoded from ``index.msgpack``. The rest of this notebook
        treats it as a list with one dict per evaluated model
        (NOTE(review): the exact schema is defined by the producer; confirm
        against the pefe-ief library).

    Raises:
        OSError: If ``index.msgpack`` cannot be opened (e.g. it is missing).
    """
    index_file_path = Path(results_dir) / "index.msgpack"
    with open(index_file_path, 'rb') as index_file:
        # raw=False asks msgpack to decode raw string values to ``str``
        # instead of leaving them as ``bytes``.
        return msgpack.unpack(index_file, raw=False)
b) Plotting¶
In [3]:
import matplotlib.pyplot as plt
class PLOT_METRICS:
    """Bit flags selecting which metrics the plotting helpers draw.

    Flags can be combined with ``|``; ``PLOT_METRICS.ALL`` (attached right
    after the class body) is the union of every flag declared here.
    """
    ROC_AUC = 1
    ACCURACY = 2
    F1 = 4
    PRECISION = 8
    RECALL = 16


# Derive ALL from the declared flags so that adding a new flag above
# automatically extends it.
_PLOT_METRICS_ALL = 0
for attr, value in vars(PLOT_METRICS).items():
    # Only public integer attributes are flags. Underscore-prefixed entries
    # must be skipped: the interpreter stores its own bookkeeping in the class
    # namespace, and on Python 3.13+ that includes the integer
    # ``__firstlineno__``, which would otherwise leak stray bits into ALL.
    if attr.startswith('_'):
        continue
    # bool is a subclass of int but is never a metric flag.
    if isinstance(value, bool) or not isinstance(value, int):
        continue
    _PLOT_METRICS_ALL |= value
PLOT_METRICS.ALL = _PLOT_METRICS_ALL
def plot_metrics_one_threshold(results, metrics, i_threshold):
    # type: (list[dict[str, dict[str, int|float|str]]], int, int) -> None
    """Draw one line chart comparing all models at a single threshold.

    Args:
        results: Per-model result dicts; each is read for ``MODEL.name``,
            ``common_stats.roc_auc`` and ``stats_per_thresholds``.
        metrics: Bitwise OR of ``PLOT_METRICS`` flags choosing the lines drawn.
        i_threshold: Index into each model's ``stats_per_thresholds`` list.
    """
    labels = [entry['MODEL']['name'] for entry in results]
    plt.figure(figsize=(12,6))

    def at_threshold(key):
        # Collect one per-threshold statistic across all models.
        return [entry['stats_per_thresholds'][i_threshold][key] for entry in results]

    def roc_auc():
        # ROC-AUC is read from common_stats, not from the per-threshold list.
        return [entry['common_stats']['roc_auc'] for entry in results]

    # (flag, legend label, lazy value getter) in drawing order; values are
    # only looked up for the flags that were actually requested.
    series = (
        (PLOT_METRICS.ROC_AUC, "ROC-AUC", roc_auc),
        (PLOT_METRICS.ACCURACY, "Accuracy", lambda: at_threshold('accuracy')),
        (PLOT_METRICS.F1, "F1", lambda: at_threshold('f1')),
        (PLOT_METRICS.PRECISION, "Precision", lambda: at_threshold('precision')),
        (PLOT_METRICS.RECALL, "Recall", lambda: at_threshold('recall')),
    )
    for flag, legend_label, get_values in series:
        if metrics & flag:
            plt.plot(labels, get_values(), marker='o', label=legend_label)

    plt.xticks(rotation=45, ha='right')
    plt.ylabel("Score")
    plt.title("PE Malware Detection Models' Performance")
    plt.legend(loc='lower right')
    plt.tight_layout()
    plt.show()
def plot_metrics(results, metrics):
    # type: (list[dict[str, dict[str, int|float|str]]], int) -> None
    """
    Plot the selected metrics for every threshold, one figure per threshold.

    Each figure is preceded by an HTML heading showing the threshold value.
    The list of thresholds is taken from the first entry of ``results``
    (NOTE(review): this assumes every entry was evaluated at the same
    thresholds in the same order; confirm against the producer).

    If ``results`` is empty, a short notice is printed and nothing is plotted.

    Usage:
        plot_metrics(
            results,
            PLOT_METRICS.ROC_AUC
            | PLOT_METRICS.ACCURACY
            | PLOT_METRICS.F1
            ...  # more if needed
        )
    """
    # Guard first: indexing results[0] below would raise IndexError otherwise.
    if not results:
        print("No results to plot.")
        return

    import ipywidgets as widgets
    from IPython.display import display

    for i_threshold, stats in enumerate(results[0]['stats_per_thresholds']):
        threshold = stats['threshold']
        html_label = widgets.HTML(value='<p style="font-size:24px;">Threshold = ' + str(threshold) + '</p>')
        display(html_label)
        plot_metrics_one_threshold(results, metrics, i_threshold)
c) Displaying pre-rendered curves¶
In [4]:
import base64
from tqdm import tqdm
def to_base64_uri(path):
    """Return the file at ``path`` as a base64 ``data:image/png`` URI string.

    The file is read as raw bytes; the ``image/png`` media type is fixed and
    is not derived from the file's name or contents.
    """
    with open(path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read())
    return "data:image/png;base64," + encoded.decode("utf-8")
def display_images_in_groups(image_groups):
    """Show groups of captioned images inside an accordion widget.

    Args:
        image_groups: Iterable of dicts, each with a ``caption`` (used as the
            accordion section title) and ``images``: a list of dicts carrying
            a ``url`` (placed in the ``<img src>``) and a ``caption``.

    One Output widget is rendered per group (with a tqdm progress bar over the
    groups), then all of them are displayed together in a single Accordion.
    """
    import ipywidgets as widgets
    from IPython.display import display, HTML

    # HTML fragments: wrapper opening, one captioned image cell, wrapper closing.
    wrapper_open = '''
<div style="display: flex; flex-direction: column; gap: 5px; justify-content: center; align-items: center">
<div style="display: grid; grid-template-columns: auto auto;">'''
    image_cell = '''
<div style="display: flex; flex-direction: column; gap: 5px; justify-content: center; align-items: center">
<img src="{}" />
<p>{}</p>
</div>'''
    wrapper_close = '''
</div>
</div>
'''

    panels = []
    panel_titles = []
    for group in tqdm(image_groups):
        panel = widgets.Output()
        # Everything, including building the markup, happens while the Output
        # widget is the active capture target, exactly as in the inline form.
        with panel:
            cells = "".join(
                image_cell.format(picture['url'], picture['caption'])
                for picture in group['images']
            )
            display(HTML(wrapper_open + cells + wrapper_close))
        panels.append(panel)
        panel_titles.append(group['caption'])

    display(widgets.Accordion(children=panels, titles=panel_titles))
def display_prerendered_curves(results):
    # type: (list[dict[str, dict[str, int|float|str]]]) -> None
    """Display each model's pre-rendered curve images, grouped per model.

    Args:
        results: Per-model result dicts. For each one, ``MODEL.name`` becomes
            the group caption, and every item of ``curves`` contributes one
            image: ``type`` is its caption and the file at ``plot_path`` is
            embedded as a base64 data URI.

    The image files are read from ``plot_path`` at call time, so those paths
    must be readable from the machine running the kernel.
    """
    # No imports are needed here: rendering is delegated entirely to
    # display_images_in_groups(), so the function no longer requires the
    # optional ``ipyplot`` package just to be callable.
    image_groups = [
        {
            "caption": r["MODEL"]["name"],
            "images": [
                {
                    "caption": c["type"],
                    "url": to_base64_uri(c["plot_path"]),
                }
                for c in r["curves"]
            ],
        }
        for r in results
    ]
    display_images_in_groups(image_groups)
2. Your Playground¶
a) Loading pre-rendered results¶
Just run this once.
In [5]:
# Load the results index from RESULTS_DIR; the plotting and display cells
# further down all take RESULTS as their input.
RESULTS = load_results(RESULTS_DIR)
# Bare last expression: renders the loaded structure as this cell's output.
RESULTS
Out[5]:
[{'dataset': {'total_count': 403032,
'malware_count': 215930,
'benign_count': 187102},
'stats_per_thresholds': [{'threshold': 0.5,
'total_hits': 393972,
'total_misses': 9060,
'accuracy': 0.9775203954028464,
'TP': 211615,
'TN': 182357,
'FP': 4745,
'FN': 4315,
'precision': 0.9800166720696522,
'recall': 0.9800166720696522,
'f1': 0.9800166720696522},
{'threshold': 0.6,
'total_hits': 393878,
'total_misses': 9154,
'accuracy': 0.9772871633021696,
'TP': 210716,
'TN': 183162,
'FP': 3940,
'FN': 5214,
'precision': 0.9758532857870607,
'recall': 0.9758532857870607,
'f1': 0.9758532857870607},
{'threshold': 0.7,
'total_hits': 393047,
'total_misses': 9985,
'accuracy': 0.9752252922844836,
'TP': 209210,
'TN': 183837,
'FP': 3265,
'FN': 6720,
'precision': 0.9688788033158894,
'recall': 0.9688788033158894,
'f1': 0.9688788033158894},
{'threshold': 0.8,
'total_hits': 391715,
'total_misses': 11317,
'accuracy': 0.9719203437940411,
'TP': 207295,
'TN': 184420,
'FP': 2682,
'FN': 8635,
'precision': 0.9600101884870097,
'recall': 0.9600101884870097,
'f1': 0.9600101884870097},
{'threshold': 0.85,
'total_hits': 390577,
'total_misses': 12455,
'accuracy': 0.9690967466603149,
'TP': 205752,
'TN': 184825,
'FP': 2277,
'FN': 10178,
'precision': 0.9528643541888575,
'recall': 0.9528643541888575,
'f1': 0.9528643541888575},
{'threshold': 0.9,
'total_hits': 388727,
'total_misses': 14305,
'accuracy': 0.9645065404235892,
'TP': 203354,
'TN': 185373,
'FP': 1729,
'FN': 12576,
'precision': 0.9417589033483074,
'recall': 0.9417589033483074,
'f1': 0.9417589033483074}],
'curves': [{'type': 'ROC',
'plot_path': '/home/lam/DeepMalNet/RESULTS/images/epoch16_1757895844.926428_IEF_ROC.png'},
{'type': 'DET',
'plot_path': '/home/lam/DeepMalNet/RESULTS/images/epoch16_1757895844.926428_IEF_DET.png'},
{'type': 'Actual Positives',
'plot_path': '/home/lam/DeepMalNet/RESULTS/images/epoch16_1757895844.926428_IEF_TPR_FNR_per_threshold_aka_Actual_Positives.png'},
{'type': 'Actual Negatives',
'plot_path': '/home/lam/DeepMalNet/RESULTS/images/epoch16_1757895844.926428_IEF_TNR_FFR_per_threshold_aka_Actual_Negatives.png'}],
'common_stats': {'roc_auc': 0.994884938053634},
'MODEL': {'type': 'InferenceDeepMalNetModel',
'path': '/mnt/scsi/lam/DeepMalNet-artifacts/epoch16_1757895844.926428.pth',
'name': 'epoch16_1757895844.926428'}}]
b) Plotting general metrics¶
In [6]:
# Plot every metric flag defined on PLOT_METRICS, one figure per threshold.
plot_metrics(RESULTS, PLOT_METRICS.ALL)
In [7]:
# Plot only accuracy, precision and recall, one figure per threshold.
plot_metrics(RESULTS, PLOT_METRICS.ACCURACY | PLOT_METRICS.PRECISION | PLOT_METRICS.RECALL)
In [8]:
# Plot only F1 and ROC-AUC. ROC-AUC is read from common_stats, so its line is
# the same in every per-threshold figure.
plot_metrics(RESULTS, PLOT_METRICS.F1 | PLOT_METRICS.ROC_AUC)
c) Displaying pre-rendered ROC, DET curves¶
In [9]:
# Embed each model's curve images (read from the plot_path entries in RESULTS)
# in an accordion, one section per model.
display_prerendered_curves(RESULTS)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 40.94it/s]